{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SQL translators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The purpose of this vignette is to walk through how expressions like `_.id.mean()` are converted into SQL.\n",
    "\n",
    "This process involves three parts:\n",
    "\n",
    "1. SQL translation functions, e.g. taking column \"id\" and producing the SQL \"ROUND(id)\".\n",
    "2. SQL translation from a symbolic call\n",
    "   - Converting method calls like `_.id.round(2)` to `round(_.id, 2)`\n",
    "   - Looking up SQL translators (e.g. for a \"mean\" function call)\n",
    "3. Handling SQL partitions, like in OVER clauses\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using a sqlalchemy select statement for convenience\n",
    "\n",
    "Throughout this vignette, we'll use a select statement object from sqlalchemy,\n",
    "so we can conveniently access its columns as needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SELECT id, x, y\n",
      "<class 'sqlalchemy.sql.base.ImmutableColumnCollection'>\n",
      "['id', 'x', 'y']\n"
     ]
    }
   ],
   "source": [
    "from sqlalchemy import sql\n",
    "col_names = ['id', 'x', 'y']\n",
    "sel = sql.select([sql.column(x) for x in col_names])\n",
    "\n",
    "print(sel)\n",
    "print(type(sel.columns))\n",
    "print(sel.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Translator functions\n",
    "\n",
    "A SQL translator function takes...\n",
    "\n",
    "* a first argument that is a sqlalchemy Column\n",
    "* (optional) additional arguments for the translation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A simple translator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "round(x, :round_1)\n"
     ]
    }
   ],
   "source": [
    "f_simple_round = lambda col, n: sql.func.round(col, n)\n",
    "\n",
    "sql_expr = f_simple_round(sel.columns.x, 2)\n",
    "\n",
    "print(sql_expr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most translator functions look essentially like the one above.\n",
    "\n",
    "For example, here is the round function defined for postgresql.\n",
    "One key difference is that it casts the column to a numeric beforehand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "round(CAST(x AS NUMERIC), :round_1)\n"
     ]
    }
   ],
   "source": [
    "from siuba.sql.dialects.postgresql import funcs\n",
    "\n",
    "f_round = funcs['scalar']['round']\n",
    "sql_expr = f_round(sel.columns.x, 2)\n",
    "\n",
    "print(sql_expr)"
   ]
  },
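  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch, a translator with this cast can be written using only sqlalchemy. This is not siuba's actual definition: `round_numeric` is a hypothetical name, and its default of `n = 0` is an assumption. The cast matters because PostgreSQL defines the two-argument form of `round` only for `numeric` values.\n",
    "\n",
    "```python\n",
    "from sqlalchemy import Numeric, cast, column, func\n",
    "\n",
    "# hypothetical translator, mirroring the postgresql output shown above\n",
    "def round_numeric(col, n = 0):\n",
    "    # cast to NUMERIC first, then round to n digits\n",
    "    return func.round(cast(col, Numeric()), n)\n",
    "\n",
    "print(round_numeric(column('x'), 2))\n",
    "```\n",
    "\n",
    "Like `f_simple_round`, it takes a column first, followed by any extra arguments."
   ]
  },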
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Handling windows with custom Over clauses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'siuba.sql.translate.AggOver'>\n",
      "avg(x) OVER ()\n"
     ]
    }
   ],
   "source": [
    "f_win_mean = funcs['window']['mean']\n",
    "\n",
    "sql_over_expr = f_win_mean(sel.columns.x)\n",
    "\n",
    "print(type(sql_over_expr))\n",
    "print(sql_over_expr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that this window expression has an empty over clause. This clause needs to be able to include any variables we've grouped the data by.\n",
    "\n",
    "Siuba handles this by implementing a `set_over` method on these custom sqlalchemy Over clauses, which takes grouping and ordering variables as arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avg(x) OVER (PARTITION BY x, y)\n"
     ]
    }
   ],
   "source": [
    "group_by_clause = sql.elements.ClauseList(sel.columns.x, sel.columns.y)\n",
    "print(sql_over_expr.set_over(group_by_clause))"
   ]
  },
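  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, stock sqlalchemy can build the same clause when the grouping columns are known up front. The sketch below does not use siuba's `set_over`, which, as described above, adds the grouping after the expression has been created.\n",
    "\n",
    "```python\n",
    "from sqlalchemy import column, func\n",
    "\n",
    "# partition columns are supplied when the OVER clause is constructed\n",
    "group_cols = [column('x'), column('y')]\n",
    "over_expr = func.avg(column('x')).over(partition_by = group_cols)\n",
    "\n",
    "print(over_expr)\n",
    "```"
   ]
  },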
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Call shaping\n",
    "\n",
    "The section above showed that SQL translators are functions that take a sqlalchemy column and return a SQL expression. However, when using siuba we often have expressions like...\n",
    "\n",
    "```\n",
    "mutate(data, x = _.y.round(2))\n",
    "```\n",
    "\n",
    "In this case, before we can even use a SQL translator, we need to...\n",
    "\n",
    "* find the name and arguments of the method being called\n",
    "* find the column it is being called on\n",
    "\n",
    "This is done by using the `CallTreeLocal` class to analyze the tree of operations for each expression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "█─'__call__'\n",
       "├─█─.\n",
       "│ ├─█─.\n",
       "│ │ ├─_\n",
       "│ │ └─'y'\n",
       "│ └─'round'\n",
       "└─2"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from siuba.siu import Lazy, CallTreeLocal, Call, strip_symbolic\n",
    "from siuba import _\n",
    "\n",
    "_.y.round(2)"
   ]
  },
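  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two lookups listed above (the method's name and arguments, and the column it is called on) can be sketched in plain Python. This is an illustration only: `Sym` and `method_parts` are hypothetical names, nested tuples stand in for siuba's call tree, and none of this is siuba's implementation.\n",
    "\n",
    "```python\n",
    "class Sym:\n",
    "    # hypothetical stand-in for `_`: records operations as nested tuples\n",
    "    def __init__(self, tree = '_'):\n",
    "        self.tree = tree\n",
    "\n",
    "    def __getattr__(self, name):\n",
    "        # attribute access becomes a '.' node\n",
    "        return Sym(('.', self.tree, name))\n",
    "\n",
    "    def __call__(self, *args):\n",
    "        # calling becomes a '__call__' node holding the arguments\n",
    "        return Sym(('__call__', self.tree) + args)\n",
    "\n",
    "def method_parts(tree):\n",
    "    # unpack ('__call__', ('.', ('.', '_', column), method), *args)\n",
    "    call_op, target, *args = tree\n",
    "    get_op, obj, method = target\n",
    "    column = obj[2]\n",
    "    return column, method, tuple(args)\n",
    "\n",
    "print(method_parts(Sym().y.round(2).tree))\n",
    "```\n",
    "\n",
    "With the column name, method name, and arguments in hand, the method name can be used to look up a translator, which is then called with the column and the arguments."
   ]
  },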
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example of translation with CallTreeLocal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from siuba.sql.dialects.postgresql import funcs\n",
    "\n",
    "local_funcs = {**funcs['scalar'], **funcs['window']}\n",
    "\n",
    "call_shaper = CallTreeLocal(\n",
    "    local_funcs,\n",
    "    call_sub_attr = ('dt',)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_.id.mean()\n"
     ]
    }
   ],
   "source": [
    "symbol = _.id.mean()\n",
    "call = strip_symbolic(symbol)\n",
    "print(call)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avg(id) OVER ()\n"
     ]
    }
   ],
   "source": [
    "func_call = call_shaper.enter(call)\n",
    "print(func_call(sel.columns))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the same result as when we called the SQL translator for `mean` manually!\n",
    "In that section we also showed that we can set group information, so that it takes\n",
    "an average within each group.\n",
    "\n",
    "Here, the Over clause is the entire expression, so setting group information on it is easy.\n",
    "It is harder when the Over clause is only one part of a larger expression..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<sqlalchemy.sql.elements.BinaryExpression object at 0x11889ada0>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "call2 = strip_symbolic(_.id.mean() + 1)\n",
    "func_call2 = call_shaper.enter(call2)\n",
    "\n",
    "func_call2(sel.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Handling partitions\n",
    "\n",
    "While the first section showed how siuba's custom Over clauses can add grouping info to a translation, it left out one key detail: expressions that generate Over clauses, like `_.id.mean()`, can be part of larger expressions, such as `_.id.mean() + 1`.\n",
    "\n",
    "In this case, if we look at the call tree for that expression, the top operation is the addition..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "█─+\n",
       "├─█─'__call__'\n",
       "│ └─█─.\n",
       "│   ├─█─.\n",
       "│   │ ├─_\n",
       "│   │ └─'id'\n",
       "│   └─'mean'\n",
       "└─1"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_.id.mean() + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How can we create the appropriate expression...\n",
    "\n",
    "```\n",
    "avg(some_col) OVER (PARTITION BY x, y) + 1\n",
    "```\n",
    "\n",
    "when the piece that needs grouping info is not easily accessible? The answer is to use a tree visitor, which steps down through every black rectangle in the call tree shown above, from top to bottom.\n"
   ]
  },
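  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of a tree visitor can be sketched with a toy expression tree. This is a simplified stand-in, not siuba's or sqlalchemy's visitor: `Window`, `BinOp`, `set_partitions`, and `render` are all hypothetical names.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class Window:\n",
    "    # hypothetical node for something like avg(id) OVER (...)\n",
    "    func: str\n",
    "    col: str\n",
    "    partition_by: tuple = ()\n",
    "\n",
    "@dataclass\n",
    "class BinOp:\n",
    "    # hypothetical node for a binary operation like left + right\n",
    "    op: str\n",
    "    left: object\n",
    "    right: object\n",
    "\n",
    "def set_partitions(node, group_by):\n",
    "    # visit every node, rebuilding window nodes with the grouping columns\n",
    "    if isinstance(node, Window):\n",
    "        return Window(node.func, node.col, tuple(group_by))\n",
    "    if isinstance(node, BinOp):\n",
    "        return BinOp(\n",
    "            node.op,\n",
    "            set_partitions(node.left, group_by),\n",
    "            set_partitions(node.right, group_by),\n",
    "        )\n",
    "    # anything else (e.g. a literal) is returned unchanged\n",
    "    return node\n",
    "\n",
    "def render(node):\n",
    "    # turn the toy tree into a SQL-like string\n",
    "    if isinstance(node, Window):\n",
    "        parts = ', '.join(node.partition_by)\n",
    "        over = f'PARTITION BY {parts}' if parts else ''\n",
    "        return f'{node.func}({node.col}) OVER ({over})'\n",
    "    if isinstance(node, BinOp):\n",
    "        return f'{render(node.left)} {node.op} {render(node.right)}'\n",
    "    return str(node)\n",
    "\n",
    "expr = BinOp('+', Window('avg', 'id'), 1)\n",
    "print(render(set_partitions(expr, ['x', 'y'])))\n",
    "```\n",
    "\n",
    "Because the visitor recurses through every node, the window is found no matter how deeply it is nested in the surrounding expression."
   ]
  },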
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Full example\n",
    "\n",
    "Below, we repeat the setup code from the call shaping section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from siuba.sql.verbs import track_call_windows\n",
    "from siuba import _\n",
    "from siuba.sql.dialects.postgresql import funcs\n",
    "\n",
    "local_funcs = {**funcs['scalar'], **funcs['window']}\n",
    "\n",
    "call_shaper = CallTreeLocal(\n",
    "    local_funcs,\n",
    "    call_sub_attr = ('dt',)\n",
    "    )\n",
    "\n",
    "symbol3 = _.id.mean() + 1\n",
    "call3 = strip_symbolic(symbol3)\n",
    "func_call3 = call_shaper.enter(call3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
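    "Finally, we pass the shaped call to `track_call_windows`, along with the columns and the grouping variables. It returns the translated column expression, a list of the window clauses it found, and (optionally) a copy of the select statement with those window clauses added."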
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avg(id) OVER (PARTITION BY x, y) + :param_1\n",
      "[<sqlalchemy.sql.elements.Label object at 0x11889ec50>, <sqlalchemy.sql.elements.Label object at 0x11889e1d0>]\n",
      "SELECT id, x, y, avg(id) OVER (PARTITION BY x, y) AS win1, avg(id) OVER (PARTITION BY x, y) + :param_1 AS win2 \n",
      "FROM (SELECT id, x, y)\n"
     ]
    }
   ],
   "source": [
    "col, windows, window_cte = track_call_windows(\n",
    "    func_call3,\n",
    "    sel.columns,\n",
    "    group_by = ['x', 'y'],\n",
    "    order_by = [],\n",
    "    # note that this is optional, and results in window_cte being a\n",
    "    # copy of this select that contains the window clauses\n",
    "    window_cte = sel.select()\n",
    "    )\n",
    "\n",
    "print(col)\n",
    "print(windows)\n",
    "print(window_cte)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
